{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# VAE for MNIST clustering and generation\n",
    "\n",
    "The goal of this notebook is to explore some recent works dealing with variational auto-encoder (VAE).\n",
    "\n",
    "We will use MNIST dataset and a basic VAE architecture. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torchvision\n",
    "from torchvision import transforms\n",
    "from torchvision.utils import save_image\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "from sklearn.metrics.cluster import normalized_mutual_info_score\n",
    "\n",
    "def show(img):\n",
    "    npimg = img.numpy()\n",
    "    plt.imshow(np.transpose(npimg, (1,2,0)), interpolation='nearest')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Device configuration\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "# Create a directory if not exists\n",
    "sample_dir = 'samples'\n",
    "if not os.path.exists(sample_dir):\n",
    "    os.makedirs(sample_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 128\n",
    "#to be modified\n",
    "data_dir = '/home/mlelarge/data'\n",
    "# MNIST dataset\n",
    "dataset = torchvision.datasets.MNIST(root=data_dir,\n",
    "                                     train=True,\n",
    "                                     transform=transforms.ToTensor(),\n",
    "                                     download=True)\n",
    "\n",
    "# Data loader\n",
    "data_loader = torch.utils.data.DataLoader(dataset=dataset,\n",
    "                                          batch_size=batch_size, \n",
    "                                          shuffle=True)\n",
    "\n",
    "test_loader = torch.utils.data.DataLoader(\n",
    "    torchvision.datasets.MNIST(data_dir, train=False, download=True, transform=transforms.ToTensor()),\n",
    "    batch_size=10, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Variational Autoencoders\n",
    "\n",
    "Consider a latent variable model with a data variable $x\\in \\mathcal{X}$ and a latent variable $z\\in \\mathcal{Z}$, $p(z,x) = p(z)p_\\theta(x|z)$. Given the data $x_1,\\dots, x_n$, we want to train the model by maximizing the marginal log-likelihood:\n",
    "\\begin{eqnarray*}\n",
    "\\mathcal{L} = \\mathbf{E}_{p_d(x)}\\left[\\log p_\\theta(x)\\right]=\\mathbf{E}_{p_d(x)}\\left[\\log \\int_{\\mathcal{Z}}p_{\\theta}(x|z)p(z)dz\\right],\n",
    "  \\end{eqnarray*}\n",
    "  where $p_d$ denotes the empirical distribution of $X$: $p_d(x) =\\frac{1}{n}\\sum_{i=1}^n \\delta_{x_i}(x)$.\n",
    "\n",
    " To avoid the (often) difficult computation of the integral above, the idea behind variational methods is to instea maximize a lower bound to the log-likelihood:\n",
    "  \\begin{eqnarray*}\n",
    "\\mathcal{L} \\geq L(p_\\theta(x|z),q(z|x)) =\\mathbf{E}_{p_d(x)}\\left[\\mathbf{E}_{q(z|x)}\\left[\\log p_\\theta(x|z)\\right]-\\mathrm{KL}\\left( q(z|x)||p(z)\\right)\\right].\n",
    "  \\end{eqnarray*}\n",
    "  Any choice of $q(z|x)$ gives a valid lower bound. Variational autoencoders replace the variational posterior $q(z|x)$ by an inference network $q_{\\phi}(z|x)$ that is trained together with $p_{\\theta}(x|z)$ to jointly maximize $L(p_\\theta,q_\\phi)$. The variational posterior $q_{\\phi}(z|x)$ is also called the encoder and the generative model $p_{\\theta}(x|z)$, the decoder or generator.\n",
    "\n",
    "The first term $\\mathbf{E}_{q(z|x)}\\left[\\log p_\\theta(x|z)\\right]$ is the negative reconstruction error. Indeed under a gaussian assumption i.e. $p_{\\theta}(x|z) = \\mathcal{N}(\\mu_{\\theta}(z), 1)$ the term $\\log p_\\theta(x|z)$ reduced to $\\propto \\|x-\\mu_\\theta(z)\\|^2$, which is often used in practice. The term $\\mathrm{KL}\\left( q(z|x)||p(z)\\right)$ can be seen as a regularization term, where the variational posterior $q_\\phi(z|x)$ should be matched to the prior $p(z)= \\mathcal{N}(0,1)$.\n",
    "\n",
    "Variational Autoencoders were introduced by [Kingma and Welling](https://arxiv.org/abs/1312.6114), see also [Doersch](https://arxiv.org/abs/1606.05908) for a tutorial.\n",
    "\n",
    "There are vairous examples of VAE in pytorch available [here](https://github.com/pytorch/examples/tree/master/vae) or [here](https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/03-advanced/variational_autoencoder/main.py#L38-L65). The code below is taken from this last source."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyper-parameters\n",
    "image_size = 784\n",
    "h_dim = 400\n",
    "z_dim = 20\n",
    "num_epochs = 15\n",
    "learning_rate = 1e-3\n",
    "\n",
    "# VAE model\n",
    "class VAE(nn.Module):\n",
    "    def __init__(self, image_size=784, h_dim=400, z_dim=20):\n",
    "        super(VAE, self).__init__()\n",
    "        self.fc1 = nn.Linear(image_size, h_dim)\n",
    "        self.fc2 = nn.Linear(h_dim, z_dim)\n",
    "        self.fc3 = nn.Linear(h_dim, z_dim)\n",
    "        self.fc4 = nn.Linear(z_dim, h_dim)\n",
    "        self.fc5 = nn.Linear(h_dim, image_size)\n",
    "        \n",
    "    def encode(self, x):\n",
    "        h = F.relu(self.fc1(x))\n",
    "        return self.fc2(h), self.fc3(h)\n",
    "    \n",
    "    def reparameterize(self, mu, log_var):\n",
    "        std = torch.exp(log_var/2)\n",
    "        eps = torch.randn_like(std)\n",
    "        return mu + eps * std\n",
    "\n",
    "    def decode(self, z):\n",
    "        h = F.relu(self.fc4(z))\n",
    "        return torch.sigmoid(self.fc5(h))\n",
    "    \n",
    "    def forward(self, x):\n",
    "        mu, log_var = self.encode(x)\n",
    "        z = self.reparameterize(mu, log_var)\n",
    "        x_reconst = self.decode(z)\n",
    "        return x_reconst, mu, log_var\n",
    "\n",
    "model = VAE().to(device)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here for the loss, instead of MSE for the reconstruction loss, we take BCE. The code below is still from the pytorch tutorial (with minor modifications to avoid warnings!)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start training\n",
    "for epoch in range(num_epochs):\n",
    "    for i, (x, _) in enumerate(data_loader):\n",
    "        # Forward pass\n",
    "        x = x.to(device).view(-1, image_size)\n",
    "        x_reconst, mu, log_var = model(x)\n",
    "        \n",
    "        # Compute reconstruction loss and kl divergence\n",
    "        # For KL divergence, see Appendix B in VAE paper\n",
    "        reconst_loss = F.binary_cross_entropy(x_reconst, x, reduction='sum')\n",
    "        kl_div = - 0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())\n",
    "        \n",
    "        # Backprop and optimize\n",
    "        loss = reconst_loss + kl_div\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        \n",
    "        if (i+1) % 10 == 0:\n",
    "            print (\"Epoch[{}/{}], Step [{}/{}], Reconst Loss: {:.4f}, KL Div: {:.4f}\" \n",
    "                   .format(epoch+1, num_epochs, i+1, len(data_loader), reconst_loss.item()/batch_size, kl_div.item()/batch_size))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let see how our network reconstructs our last batch. We display pairs of original digits and reconstructed version."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mu, _ = model.encode(x) \n",
    "out = model.decode(mu)\n",
    "x_concat = torch.cat([x.view(-1, 1, 28, 28), out.view(-1, 1, 28, 28)], dim=3)\n",
    "out_grid = torchvision.utils.make_grid(x_concat).cpu().data\n",
    "show(out_grid)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let see now, how our network generates new samples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "        z = torch.randn(16, z_dim).to(device)\n",
    "        out = model.decode(z).view(-1, 1, 28, 28)\n",
    "\n",
    "out_grid = torchvision.utils.make_grid(out).cpu()\n",
    "show(out_grid)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Not great, but we did not train our network for long... That being said, we have no control of the generated digits. In the rest of this jupyter, we explore ways to generates zeros, ones, twos and so on. As a by product, we show how our VAE will allow us to do clustering.\n",
    "\n",
    "The main idea is to build what we call a Gumbel VAE as described below."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Gumbel VAE\n",
    "\n",
    "Implement a VAE where you add a categorical variable $c\\in \\{0,\\dots 9\\}$ so that your latent variable model is $p(c,z,x) = p(c)p(z)p_{\\theta}(x|,c,z)$ and your variational posterior is $q_{\\phi}(c|x)q_{\\phi}(z|x)$ as described in this NIPS [paper](https://arxiv.org/abs/1804.00104). Make minimal modifications to previous architecture...\n",
    "\n",
    "The idea is that you incorporates a categorical variable in your latent space. You hope that this categorical variable will encode the class of the digit, so that your network can use it for a better reconstruction. Moreover, if things work as planed, you will then be able to generate digits conditionally to the class, i.e. you can choose the class thanks to the latent categorical variable $c$ and then generate digits from this class.\n",
    "\n",
    "As noticed above, in order to sample random variables while still being able to use backpropagation required us to use the reparameterization trick which is easy for Gaussian random variables. For categorical random variables, the reparameterization trick is explained in this [paper](https://arxiv.org/abs/1611.01144). This is implemented in pytorch thanks to [F.gumbel_softmax](https://pytorch.org/docs/stable/nn.html?highlight=gumbel_softmax#torch.nn.functional.gumbel_softmax)"
   ]
  },
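  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a reminder of what the Gumbel-softmax relaxation computes (a sketch; the symbols $\\pi_k$, $g_k$, $u_k$, $\\tau$, $K$ and $y_k$ are notation introduced here, not taken from the rest of this notebook): given class probabilities $\\pi_1,\\dots,\\pi_K$, draw independent $u_k$ uniformly on $(0,1)$, set $g_k = -\\log(-\\log u_k)$ and\n",
    "\\begin{eqnarray*}\n",
    "y_k = \\frac{\\exp\\left((\\log \\pi_k + g_k)/\\tau\\right)}{\\sum_{j=1}^K \\exp\\left((\\log \\pi_j + g_j)/\\tau\\right)},\\quad k=1,\\dots,K.\n",
    "\\end{eqnarray*}\n",
    "The vector $y$ is a differentiable function of $\\log \\pi$, so gradients can flow through the sample. As the temperature $\\tau$ goes to $0$, $y$ approaches a one-hot vector whose hot index is distributed according to $\\pi$."
   ]
  },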
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_classes = 10\n",
    "\n",
    "class VAE_Gumbel(nn.Module):\n",
    "    def __init__(self, image_size=784, h_dim=400, z_dim=20, n_classes = 10):\n",
    "        super(VAE_Gumbel, self).__init__()\n",
    "        #\n",
    "        # your code here\n",
    "        #\n",
    "        \n",
    "        \n",
    "    def encode(self, x):\n",
    "        #\n",
    "        # your code here / use F.log_softmax\n",
    "        #\n",
    "        \n",
    "    \n",
    "    def reparameterize(self, mu, log_var):\n",
    "        std = torch.exp(log_var/2)\n",
    "        eps = torch.randn_like(std)\n",
    "        return mu + eps * std\n",
    "\n",
    "    def decode(self, z, y_onehot):\n",
    "        #\n",
    "        # your code here / use torch.cat \n",
    "        #\n",
    "        \n",
    "    \n",
    "    def forward(self, x):\n",
    "        #\n",
    "        # your code here / use F.gumbel_softmax\n",
    "        #\n",
    "        \n",
    "\n",
    "model_G = VAE_Gumbel().to(device)\n",
    "optimizer = torch.optim.Adam(model_G.parameters(), lr=learning_rate)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You need to modify the loss to take into account the categorical random variable with an uniform prior on $\\{0,\\dots 9\\}$, see Appendix A.2 in the NIPS [paper](https://arxiv.org/abs/1804.00104)"
   ]
  },
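  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For a uniform prior $p(c)=1/K$ on $K$ classes (here $K=10$), the KL term of the categorical variable is a simple function of the entropy of the variational posterior. Writing $q_k = q_\\phi(c=k|x)$ and $H$ for the entropy (notation introduced here for illustration):\n",
    "\\begin{eqnarray*}\n",
    "\\mathrm{KL}\\left( q_\\phi(c|x)||p(c)\\right) = \\sum_{k=0}^{K-1} q_k\\log \\frac{q_k}{1/K} = \\log K - H\\left(q_\\phi(c|x)\\right),\\quad H\\left(q_\\phi(c|x)\\right) = -\\sum_{k=0}^{K-1} q_k \\log q_k.\n",
    "\\end{eqnarray*}\n",
    "This KL is zero when $q_\\phi(c|x)$ is uniform and equals $\\log K$ when it is one-hot. The `H_cat` value logged as `Entropy` in the training loop below is presumably this entropy summed over the batch, but that reading is an assumption."
   ]
  },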
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_G(model, data_loader=data_loader,num_epochs=num_epochs, beta = 1., verbose=True):\n",
    "    nmi_scores = []\n",
    "    model.train(True)\n",
    "    for epoch in range(num_epochs):\n",
    "        all_labels = []\n",
    "        all_labels_est = []\n",
    "        for i, (x, labels) in enumerate(data_loader):\n",
    "            # Forward pass\n",
    "            x = x.to(device).view(-1, image_size)\n",
    "            #\n",
    "            # your code here\n",
    "            #\n",
    "            \n",
    "            reconst_loss = F.binary_cross_entropy(x_reconst, x, reduction='sum')\n",
    "            #\n",
    "            # your code here\n",
    "            #\n",
    "\n",
    "            # Backprop and optimize\n",
    "            loss = # your code here\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            \n",
    "            if verbose:\n",
    "                if (i+1) % 10 == 0:\n",
    "                    print (\"Epoch[{}/{}], Step [{}/{}], Reconst Loss: {:.4f}, KL Div: {:.4f}, Entropy: {:.4f}\" \n",
    "                           .format(epoch+1, num_epochs, i+1, len(data_loader), reconst_loss.item()/batch_size,\n",
    "                                   kl_div.item()/batch_size, H_cat.item()/batch_size))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_G(model_G,num_epochs=10,verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x,_ = next(iter(data_loader))\n",
    "x = x[:24,:,:,:].to(device)\n",
    "out, _, _, log_p = model_G(x.view(-1, image_size)) \n",
    "x_concat = torch.cat([x.view(-1, 1, 28, 28), out.view(-1, 1, 28, 28)], dim=3)\n",
    "out_grid = torchvision.utils.make_grid(x_concat).cpu().data\n",
    "show(out_grid)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This was for reconstruction, but we care more about generation. For each category, we are generating 8 samples thanks to the following matrix, so that in the end, we should have on each line only one digit represented."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "matrix = np.zeros((8,n_classes))\n",
    "matrix[:,0] = 1\n",
    "final = matrix[:]\n",
    "for i in range(1,n_classes):\n",
    "    final = np.vstack((final,np.roll(matrix,i)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "        z = torch.randn(8*n_classes, z_dim).to(device)\n",
    "        y_onehot = torch.tensor(final).type(torch.FloatTensor).to(device)\n",
    "        out = model_G.decode(z,y_onehot).view(-1, 1, 28, 28)\n",
    "\n",
    "out_grid = torchvision.utils.make_grid(out).cpu()\n",
    "show(out_grid)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It does not look like our original idea is working...\n",
    "\n",
    "To check that our network is not using the categorical variable, we can track the [normalized mutual information](http://scikit-learn.org/stable/modules/generated/sklearn.metrics.normalized_mutual_info_score.html) between the true labels and the labels 'predicted' by our network (just by taking the category with maximal probability). Change your training loop to return the normalized mutual information (NMI) for each epoch. Plot the curve to check that the NMI is actually decreasing."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In order to force our network to use the categorical variable, we will change the loss following this ICLR [paper](https://openreview.net/forum?id=Sy2fzU9gl)\n",
    "\n",
    "Implement this change in the training loop and plot the new NMI curve after 10 epochs. For $\\beta = 20$, you should see that NMI increases. But reconstruction starts to be bad and generation is still poor.\n",
    "\n",
    "This is explained in this [paper](https://arxiv.org/abs/1804.03599) and a solution is proposed see Section 5. Implement the solution described in Section 3 equation (7) if the NIPS [paper](https://arxiv.org/abs/1804.00104) "
   ]
  },
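  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a reminder, here is a sketch of the two objectives mentioned above, written with the notation of this notebook. How the weight is shared between the two KL terms, and the identification of $\\beta$ with the weight used in the NIPS paper, are assumptions to be checked against the papers. The $\\beta$-weighted objective to maximize reads\n",
    "\\begin{eqnarray*}\n",
    "\\mathbf{E}_{q_\\phi(c|x)q_\\phi(z|x)}\\left[\\log p_\\theta(x|c,z)\\right]-\\beta\\,\\mathrm{KL}\\left( q_\\phi(z|x)||p(z)\\right)-\\beta\\,\\mathrm{KL}\\left( q_\\phi(c|x)||p(c)\\right),\n",
    "\\end{eqnarray*}\n",
    "and its capacity-controlled variant reads\n",
    "\\begin{eqnarray*}\n",
    "\\mathbf{E}_{q_\\phi(c|x)q_\\phi(z|x)}\\left[\\log p_\\theta(x|c,z)\\right]-\\beta\\left|\\mathrm{KL}\\left( q_\\phi(z|x)||p(z)\\right)-C_z\\right|-\\beta\\left|\\mathrm{KL}\\left( q_\\phi(c|x)||p(c)\\right)-C_c\\right|,\n",
    "\\end{eqnarray*}\n",
    "where the capacities $C_z, C_c\\geq 0$ are assumed to be increased gradually during training up to the final values passed as `C_z_fin` and `C_c_fin` in the function below. For $C_z=C_c=0$ the second objective reduces to the first, since a KL divergence is non-negative."
   ]
  },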
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_G = VAE_Gumbel().to(device)\n",
    "optimizer = torch.optim.Adam(model_G.parameters(), lr=learning_rate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_G_modified_loss(model, data_loader=data_loader,num_epochs=num_epochs, beta , C_z_fin, C_c_fin, verbose=True):\n",
    "    #\n",
    "    # your code here\n",
    "    #"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "        z = torch.randn(8*n_classes, z_dim).to(device)\n",
    "        y_onehot = torch.tensor(final).type(torch.FloatTensor).to(device)\n",
    "        out = model_G.decode(z,y_onehot).view(-1, 1, 28, 28)\n",
    "out_grid = torchvision.utils.make_grid(out).cpu()\n",
    "show(out_grid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "i = 1\n",
    "with torch.no_grad():\n",
    "    plt.plot()\n",
    "    z = torch.randn(8, z_dim).to(device)\n",
    "    y_onehot = torch.tensor(np.roll(matrix,i)).type(torch.FloatTensor).to(device)\n",
    "    out = model_G.decode(z,y_onehot).view(-1, 1, 28, 28)\n",
    "    out_grid = torchvision.utils.make_grid(out).cpu()\n",
    "    show(out_grid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x,_ = next(iter(data_loader))\n",
    "x = x[:24,:,:,:].to(device)\n",
    "out, _, _, log_p = model_G(x.view(-1, image_size)) \n",
    "x_concat = torch.cat([x.view(-1, 1, 28, 28), out.view(-1, 1, 28, 28)], dim=3)\n",
    "out_grid = torchvision.utils.make_grid(x_concat).cpu().data\n",
    "show(out_grid)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6+"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
